# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch
# which was authored by Yuge Zhang, 2020

import numpy as np
import torch
import torch.nn as nn

import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import kendalltau

# Select the first CUDA GPU when PyTorch can see one; fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device:', device)

def to_cuda_float32(obj):
    if torch.is_tensor(obj):
        return obj.cuda().float()
    if isinstance(obj, tuple):
        return tuple(to_cuda_float32(t) for t in obj)
    if isinstance(obj, list):
        return [to_cuda_float32(t) for t in obj]
    if isinstance(obj, dict):
        return {k: to_cuda_float32(v) for k, v in obj.items()}
    if isinstance(obj, (int, float, str)):
        return obj

    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


class PredictorModel(object):

    def __init__(self, predictor):
        self.predictor = predictor
        self.mean = 0.0
        self.std = 0.0

    def fit(self,
            train_data,
            batch_size=10,
            epochs=300,
            lr=1e-4,
            wd=1e-3,
        ):

        ytrain = []
        for v in train_data:
            ytrain.append(v["val_acc"])

        self.mean = np.mean(ytrain)
        self.std = np.std(ytrain)
        ytrain_normed = (ytrain - self.mean) / self.std

        for i in range(len(train_data)):
            if train_data[i]["val_acc"] == ytrain[i]:
                train_data[i]["val_acc"] = ytrain_normed[i]
            else:
                raise Exception("not equals")

        data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

        self.predictor.to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.predictor.parameters(), lr=lr, weight_decay=wd)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        self.predictor.train()

        for epoch in range(epochs):
            # meters = AverageMeterGroup()
            lr = optimizer.param_groups[0]["lr"]
            for step, batch in enumerate(data_loader):
                batch = to_cuda_float32(batch)
                target = batch["val_acc"]
                prediction = self.predictor(batch)
                optimizer.zero_grad()
                loss = criterion(prediction, target)
                loss.backward()
                optimizer.step()

                # mse = accuracy_mse(prediction, target)
                # meters.update({"loss": loss.item(), "mse": mse.item()}, n=target.size(0))

            print("epoch: {}, loss: {}".format(epoch, loss.item()))
            # log.info("epoch: {}, loss: {}".format(epoch, loss.item()))
            lr_scheduler.step()

    def predict(self, test_data):
        self.predictor.eval()

        data_loader = DataLoader(test_data)
        with torch.no_grad():
            for step, batch in enumerate(data_loader):
                batch = to_cuda_float32(batch)
                pred_val = self.predictor(batch)

        return pred_val.item() * self.std + self.mean

    def get_test_loss(self, test_data, eval_batch_size=1000, log=None):
        test_data_loader = DataLoader(test_data, batch_size=eval_batch_size)
        criterion = nn.MSELoss()

        self.predictor.eval()

        predict_, target_ = [], []
        with torch.no_grad():
            for step, batch in enumerate(test_data_loader):
                batch = to_cuda_float32(batch)
                target = batch["val_acc"]
                pred_val = self.predictor(batch)
                predict_.append(pred_val.cpu().numpy())
                target_.append(target.cpu().numpy())

                # predict_.append(predict.numpy())
                # target_.append(target.numpy())
                # meters.update({"loss": criterion(predict, target).item(),
                #                "mse": accuracy_mse(predict, target).item()}, n=target.size(0))

                # if (args.eval_print_freq and step % args.eval_print_freq == 0) or \
                #         step % 10 == 0 or step + 1 == len(test_data_loader):
                #     logger.info("Evaluation Step [%d/%d]  %s", step + 1, len(test_data_loader), meters)

        predict_ = np.concatenate(predict_)
        target_ = np.concatenate(target_)
        print("Func-get_test_loss:", "predict_:", predict_[:10], "target_:", target_[:10])
        # logger.info("Kendalltau: %.6f", kendalltau(predict_, target_)[0])

        result = kendalltau(predict_, target_)[0]
        print("Kendalltau: {:.6f}".format(result))
        return result


